1 Prompt

You have been provided actual data from the COVID-19 pandemic, as reported by state and national governments. These data have already had spikes or errors in reporting corrected and redistributed. Attached are data from the US (excluding Washington state):

  • Cumulative confirmed (reported) cases of COVID-19, by location by date
  • Cumulative reported deaths of COVID-19, by location by date
  • Cumulative hospitalizations for COVID-19, by location by date (NOTE: hospitalization data is not available for all 50 states, only some)
  • Population of each location.

These are the only 4 inputs that IHME uses for our first-stage deaths model, which produces 14-day forecasts that are then used in an SEIR model.

2 Set up, including file setup and loading data.

2.1 We also check the dimensions of the data and what each variable might mean

library(here)
library(tidyverse)

setwd(here::here())

data = read.csv("covid_data_cases_deaths_hosp.csv")

dim(data) #7914 observations, 9 variables


length(unique(data$location_id)) #checking that there are 51 state IDs 
length(unique(data$Province.State)) #there are 51 unique state names

length(unique(data$Date)) #there are 194 unique dates 

min(data$Date) #43856 -- converts to January 28 2020
max(data$Date) #44049 -- converts to Aug 8 2020

length(unique(data$US.STATE)) #US.STATE doesn't appear to mean anything as it is just 1 
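
A quick arithmetic check on the dimensions above: 51 locations times 194 dates would give 9894 rows, but the file has 7914, so not every location has a row for every date. The sketch below shows one way to list locations with fewer rows than dates. The `toy` data frame is made up for illustration; only the column names `Province.State` and `Date` come from the dataset.

```r
# Dimensions reported above
n_rows <- 7914
n_locations <- 51
n_dates <- 194
full_panel <- n_locations * n_dates   # rows expected if every location had every date
print(full_panel - n_rows)            # rows "missing" from a complete panel

# Toy stand-in for the real data (hypothetical values)
toy <- data.frame(
  Province.State = c("A", "A", "A", "B", "B"),
  Date = c(1, 2, 3, 2, 3)
)

# Count rows per location and flag locations with fewer rows than unique dates
rows_per_state <- table(toy$Province.State)
incomplete <- rows_per_state[rows_per_state < length(unique(toy$Date))]
print(incomplete)
```

Running the same `table()` call on `data$Province.State` would show which locations start later than the earliest date.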

3 Data cleaning

3.1 Creating a date variable

  • Cleaning the date up:
    • The date is given as a 5-digit number, which is most likely a count of days from a date of origin. I took the origin to be January 1 1900, which places the earliest date in the dataset at January 28 2020, shortly after the first US cases were reported in late January 2020, and the latest date at August 8 2020.
  • Making sure that the data contain only values of zero or more:
    • Cumulative counts should never be negative, but negative values can sometimes appear in reported data, so any negative value is set to zero.
  • Creating cases, deaths and hospitalizations per 100,000 population to make the data comparable across states
data = data %>%
  mutate(
    newdate = as.Date(Date, origin="1900-01-01"),
    cases = if_else(Confirmed < 0, 0, Confirmed),
    deaths = if_else(Deaths < 0, 0, Deaths),
    hospitalizations = if_else(Hospitalizations < 0,0, Hospitalizations),
    cases100k = (cases/population)*100000,
    deaths100k = (deaths/population)*100000,
    hosp100k = (hospitalizations/population)*100000)
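
The choice of origin matters by two days. The sketch below converts the minimum and maximum serial numbers under the origin used above and under "1899-12-30", which is the origin the R documentation for `as.Date` suggests for Excel serial dates on Windows. Whether this file's serial numbers actually come from Excel is an assumption, not something the data confirms.

```r
# Minimum and maximum values of data$Date reported earlier
serial_range <- c(43856, 44049)

# Origin used in this analysis
dates_origin_1900 <- as.Date(serial_range, origin = "1900-01-01")

# Origin conventionally used for Excel (Windows) serial dates;
# applying it here assumes the file was exported from Excel
dates_origin_excel <- as.Date(serial_range, origin = "1899-12-30")

print(dates_origin_1900)
print(dates_origin_excel)
print(dates_origin_1900 - dates_origin_excel)  # constant shift between the two conventions
```

If the source turns out to be Excel, swapping the `origin` argument in the `mutate()` call is the only change needed; the per-100k variables are unaffected.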

4 Q1. What is the relationship between cases, hospitalizations, and deaths? Describe how these indicators relate to each other, and visualize the relationships in at least 2 different ways.

4.1 Summary of cases, deaths, hospitalizations by state

This table shows the median cumulative cases, deaths and hospitalizations per 100k for each state, sorted from the highest median cases per 100k to the lowest.

summarytable = data %>%
  group_by(Province.State) %>%
  summarise(median_cases = median(cases100k, na.rm=TRUE),
            median_deaths = median(deaths100k, na.rm=TRUE),
            median_hosp = median(hosp100k, na.rm=TRUE)) %>%
    arrange(desc(median_cases)) 


summarytable

4.2 Relationship between cases and deaths.

We expect a substantial correlation: the higher the case count, the higher the number of deaths in each state. Overall, the Pearson correlation coefficient between cases per 100,000 and deaths per 100,000 is r = 0.82. We also provide the correlation between cases and deaths by state in table and figure form (though not all states have correlation coefficients, due to incomplete data).

library(viridis)
## Loading required package: viridisLite
ggplot(data=data) + 
  geom_point(aes(x=cases100k, y=deaths100k), size=0.5) +
  xlab('Cases/100k population') + 
  ylab ('Deaths/100k population') +
  facet_wrap(~Province.State)+
  theme_bw()

#correlation coefficient by state
groupcorr = data %>%
  group_by(Province.State) %>%
  summarise(correlation = cor(cases100k, deaths100k)) %>%
  filter(!is.na(correlation))

groupcorr
ggplot(groupcorr, aes(x = reorder(Province.State, correlation), y = correlation)) +
  geom_bar(stat = "identity", fill = "steelblue") +
  coord_flip() +  
  labs(title = "Correlation between Cases and Deaths per 100k by State",
       x = "State",
       y = "Correlation Coefficient") +
  theme_minimal()

#overall correlation coefficient (Pearson, the default method)
cor(data$cases100k, data$deaths100k, use = "complete.obs",
    method = "pearson") #0.82
## [1] 0.8214892
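
A note on the `method` argument of `cor()`: when the full vector of method names is supplied, only the first one (Pearson) is used. To compare Pearson, Kendall and Spearman coefficients, each has to be requested separately. A minimal sketch on simulated cumulative series (the numbers are made up and are not the COVID data):

```r
# Toy cumulative series standing in for cases and deaths (hypothetical values)
set.seed(42)
x <- cumsum(rpois(50, lambda = 30))
y <- cumsum(rpois(50, lambda = 2))

# Supplying all three method names returns only the Pearson coefficient
r_vector_call <- cor(x, y, method = c("pearson", "kendall", "spearman"))

# One coefficient per method, requested one at a time
methods <- c(pearson = "pearson", kendall = "kendall", spearman = "spearman")
r_by_method <- sapply(methods, function(m) cor(x, y, method = m))

print(r_vector_call)
print(r_by_method)
```

Because cumulative series can only increase over time, rank-based coefficients within a single state will usually sit close to 1, which is worth keeping in mind when reading the by-state correlations.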

4.3 Relationship between deaths and hospitalizations

For hospitalizations, we subset the data to rows that contain hospitalization data. This way, we avoid plotting blank points.

We see an even higher correlation between deaths and hospitalizations when broken down by state, which is expected. The overall correlation coefficient between deaths per 100,000 and hospitalizations per 100,000 is 0.93.

hospdata = data %>%
    filter(!is.na(hospitalizations))


ggplot(data=hospdata) + 
  geom_point(aes(x=hosp100k, y=deaths100k), size=0.5) +
  xlab('Hospitalizations/100k population') + 
  ylab ('Deaths/100k population') +
  facet_wrap(~Province.State)+
  theme_bw()

groupcorr2= hospdata %>%
  group_by(Province.State) %>%
  summarise(correlation = cor(hosp100k, deaths100k,  use="complete.obs")) %>%
  filter(!is.na(correlation))

groupcorr2
ggplot(groupcorr2, aes(x = reorder(Province.State, correlation), y = correlation)) +
  geom_bar(stat = "identity", fill = "lightgreen") +
  coord_flip() +  
  labs(title = "Correlation between Hospitalizations and Deaths per 100k by State",
       x = "State",
       y = "Correlation coefficient") +
  theme_minimal()

#overall correlation coefficient (Pearson, the default method)
cor(hospdata$hosp100k, hospdata$deaths100k, use = "complete.obs",
    method = "pearson") #0.93
## [1] 0.9253078

4.4 Spatial correlation

Here we plot a map showing how cases and deaths vary spatially. For this example, I have selected only the cumulative cases and deaths per 100,000 on the last recorded day (8/8/2020), since a static map can show only a single point in time. I have grouped the cumulative case and death values into a 9-level bivariate classification (3 case levels by 3 death levels) to illustrate the relationship between COVID-19 cases and deaths by state: low case/low death, low case/med death, low case/high death, med case/low death, and so on.

The map below shows how cases and deaths are related across the country on 8/8/2020. Some states in the South (Louisiana, North Carolina, South Carolina) and East (New York, Massachusetts) have both high case and high death counts.
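
To make the 9-level classification concrete, the sketch below builds the same kind of "cases-deaths" label by hand: each variable is cut into terciles and the two tercile numbers are pasted together. The `toy` values and the `tercile()` helper are made up for illustration, and this is only an approximation of what a quantile-based bivariate classifier does, not a reproduction of the mapping code.

```r
# Toy per-100k values for six hypothetical states
toy <- data.frame(
  state = c("A", "B", "C", "D", "E", "F"),
  cases100k = c(100, 250, 400, 900, 1500, 2200),
  deaths100k = c(2, 30, 5, 20, 60, 45)
)

# Hypothetical helper: assign 1 (low), 2 (medium) or 3 (high) by tercile
tercile <- function(x) {
  breaks <- quantile(x, probs = c(0, 1/3, 2/3, 1), na.rm = TRUE)
  as.integer(cut(x, breaks = breaks, include.lowest = TRUE))
}

# First digit is the case level, second digit is the death level
toy$bi_class_manual <- paste(tercile(toy$cases100k),
                             tercile(toy$deaths100k),
                             sep = "-")
print(toy)
```

A label such as "1-2" then reads as low cases and medium deaths, which matches the order of the legend labels used for the map.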

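#packages used for the bivariate map
library(sf)
library(tmap)
library(biscale)

#statemap is assumed to be an sf object of state boundaries with a Province.State column, loaded beforehand

#here we filter out the relevant variables to make the dataset more manageable
map_data = data %>%
  select(Province.State, newdate, cases100k, deaths100k, hosp100k) %>%
  filter(newdate == max(newdate))

#joins the case and death data to the state boundaries
statemap_join = left_join(map_data, statemap, by="Province.State")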


#bivariate map file
casedeathmap = bi_class(statemap_join, x = cases100k, y = deaths100k, style="quantile") %>%
   slice(-1)
casedeathmap = st_as_sf(casedeathmap)

#nine colors, one per bivariate class ("1-1" to "3-3")
color_palette = RColorBrewer::brewer.pal(9, "YlOrRd")

m5 = tm_shape(casedeathmap, projection = 2163, unit = "mi") +
  tm_polygons("bi_class", palette = color_palette, 
             border.col = "white", lwd = 0.1, legend.show=FALSE) +
  tm_shape(statemap) + 
  tm_borders(col="black", lwd=0.3, lty="solid")+
  tm_text("Province.State", size=0.6, root=3, remove.overlap=TRUE)+
  tm_layout(
    outer.margins = 0,  
    asp = 0,
       legend.width=1,
    legend.height=0.5
    )+
    tm_add_legend(
    type = "fill",
    col = color_palette,
    title = "COVID Cases and Deaths per 100k",
    is.portrait = TRUE,
    labels = c("Low Cases/Low Deaths", "Low Cases/Med Deaths", "Low Cases/High Deaths",
               "Med Cases/Low Deaths", "Med Cases/Med Deaths", "Med Cases/High Deaths",
               "High Cases/Low Deaths", "High Cases/Med Deaths", "High Cases/High Deaths")
 
  )



m5

5 Q2. Fit a curve of daily deaths, utilizing these inputs. Describe the approach you used and visualize the results.

5.1 Checking distribution of data

  • We can see that the distribution of confirmed cases, deaths and hospitalizations is right skewed.
  • The variance is greater than the mean for cases, deaths and hospitalizations.
hist(data$cases)

hist(data$deaths)

hist(data$hospitalizations)

#check if variance is greater than mean for cases
with(data, tapply(cases, Province.State, function(x) {
    sprintf("M (var) = %1.2f (%1.2f)", mean(x,na.rm=TRUE), var(x, na.rm=TRUE))
}))
##                                                                       Alabama 
##           "M (var) = 133.73 (1384.49)"    "M (var) = 25704.53 (742795215.61)" 
##                                 Alaska                                Arizona 
##         "M (var) = 774.62 (712484.96)"   "M (var) = 36259.52 (2990380977.45)" 
##                               Arkansas                             California 
##    "M (var) = 12343.89 (187082538.76)" "M (var) = 117010.72 (22632032115.99)" 
##                               Colorado                            Connecticut 
##    "M (var) = 21612.76 (218786478.31)"    "M (var) = 30973.12 (338623054.84)" 
##                               Delaware                   District of Columbia 
##      "M (var) = 7412.47 (25597794.30)"      "M (var) = 6582.37 (18420990.17)" 
##                                Florida                                Georgia 
## "M (var) = 110458.66 (20549808845.26)"   "M (var) = 55826.97 (3117505325.58)" 
##                                 Hawaii                                  Idaho 
##         "M (var) = 754.18 (319784.83)"      "M (var) = 5142.22 (37561901.47)" 
##                               Illinois                                Indiana 
##   "M (var) = 91234.97 (3835241282.45)"    "M (var) = 28995.28 (455362912.32)" 
##                                   Iowa                                 Kansas 
##    "M (var) = 17777.07 (222814033.18)"      "M (var) = 9651.47 (72514969.03)" 
##                               Kentucky                              Louisiana 
##     "M (var) = 10301.81 (86002307.43)"   "M (var) = 42288.85 (1099106001.76)" 
##                                  Maine                               Maryland 
##       "M (var) = 1950.57 (1789325.31)"    "M (var) = 43083.18 (888300841.46)" 
##                          Massachusetts                               Michigan 
##   "M (var) = 59075.84 (2332863636.98)"    "M (var) = 51031.21 (778118807.94)" 
##                              Minnesota                            Mississippi 
##    "M (var) = 20617.55 (349446084.07)"    "M (var) = 18080.76 (321366130.14)" 
##                               Missouri                                Montana 
##    "M (var) = 15424.34 (200879113.81)"        "M (var) = 955.60 (1243847.65)" 
##                               Nebraska                                 Nevada 
##     "M (var) = 11180.60 (84232938.20)"    "M (var) = 12921.20 (208118604.42)" 
##                          New Hampshire                             New Jersey 
##       "M (var) = 3452.22 (5750607.37)"  "M (var) = 118239.52 (4354790214.04)" 
##                             New Mexico                               New York 
##      "M (var) = 7769.04 (42548261.44)" "M (var) = 280235.03 (22915287915.56)" 
##                         North Carolina                           North Dakota 
##   "M (var) = 37117.99 (1560810914.28)"       "M (var) = 2369.15 (4011375.43)" 
##                                   Ohio                               Oklahoma 
##    "M (var) = 34658.71 (786174403.82)"     "M (var) = 9696.54 (115924395.88)" 
##                                 Oregon                           Pennsylvania 
##      "M (var) = 5452.09 (31697670.62)"   "M (var) = 60222.86 (1496539414.07)" 
##                           Rhode Island                         South Carolina 
##      "M (var) = 9471.77 (50947092.84)"    "M (var) = 23304.10 (797856878.30)" 
##                           South Dakota                              Tennessee 
##       "M (var) = 4188.93 (8936234.68)"   "M (var) = 29886.66 (1001545684.73)" 
##                                  Texas                                   Utah 
##  "M (var) = 97499.42 (17596205321.66)"    "M (var) = 13108.60 (169495428.35)" 
##                                Vermont                               Virginia 
##         "M (var) = 889.22 (181333.38)"    "M (var) = 37562.43 (927534909.98)" 
##                          West Virginia                              Wisconsin 
##       "M (var) = 2184.65 (3730471.36)"    "M (var) = 15173.51 (274313734.01)" 
##                                Wyoming 
##        "M (var) = 1000.89 (643525.29)"
#check if variance is greater than mean for deaths
with(data, tapply(deaths, Province.State, function(x) {
    sprintf("M (var) = %1.2f (%1.2f)", mean(x, na.rm=TRUE), var(x, na.rm=TRUE))
}))
##                                                                 Alabama 
##             "M (var) = 2.25 (1.64)"      "M (var) = 616.07 (248109.11)" 
##                              Alaska                             Arizona 
##           "M (var) = 10.77 (38.12)"     "M (var) = 877.88 (1195310.29)" 
##                            Arkansas                          California 
##       "M (var) = 158.68 (20685.71)"    "M (var) = 3057.39 (9334585.18)" 
##                            Colorado                         Connecticut 
##     "M (var) = 1043.47 (475901.05)"    "M (var) = 2737.37 (3103878.21)" 
##                            Delaware                District of Columbia 
##       "M (var) = 333.79 (53618.15)"       "M (var) = 333.37 (52442.10)" 
##                             Florida                             Georgia 
##    "M (var) = 2293.38 (4101248.45)"    "M (var) = 1683.50 (1550561.07)" 
##                              Hawaii                               Idaho 
##           "M (var) = 14.07 (67.71)"         "M (var) = 70.68 (2619.70)" 
##                            Illinois                             Indiana 
##    "M (var) = 4112.33 (8131003.40)"    "M (var) = 1591.56 (1225143.99)" 
##                                Iowa                              Kansas 
##      "M (var) = 408.90 (103383.62)"       "M (var) = 178.66 (13483.91)" 
##                            Kentucky                           Louisiana 
##       "M (var) = 383.06 (57880.22)"    "M (var) = 2180.05 (1682368.90)" 
##                               Maine                            Maryland 
##         "M (var) = 67.95 (1889.84)"    "M (var) = 2029.09 (1641686.33)" 
##                       Massachusetts                            Michigan 
##   "M (var) = 4011.56 (12895498.79)"    "M (var) = 4223.20 (5733952.34)" 
##                           Minnesota                         Mississippi 
##      "M (var) = 798.19 (411078.89)"      "M (var) = 648.05 (294810.52)" 
##                            Missouri                             Montana 
##      "M (var) = 610.76 (195501.74)"          "M (var) = 19.09 (236.90)" 
##                            Nebraska                              Nevada 
##       "M (var) = 158.45 (13421.83)"       "M (var) = 345.92 (63064.48)" 
##                       New Hampshire                          New Jersey 
##       "M (var) = 192.53 (25998.42)"   "M (var) = 9599.58 (37868657.99)" 
##                          New Mexico                            New York 
##       "M (var) = 293.80 (52011.31)" "M (var) = 21416.60 (156954791.19)" 
##                      North Carolina                        North Dakota 
##      "M (var) = 769.47 (422457.19)"         "M (var) = 48.43 (1333.15)" 
##                                Ohio                            Oklahoma 
##    "M (var) = 1720.73 (1529599.19)"       "M (var) = 266.57 (28796.52)" 
##                              Oregon                        Pennsylvania 
##        "M (var) = 131.35 (9641.35)"    "M (var) = 3957.30 (7948772.71)" 
##                        Rhode Island                      South Carolina 
##      "M (var) = 452.47 (155225.49)"      "M (var) = 535.97 (250786.18)" 
##                        South Dakota                           Tennessee 
##         "M (var) = 53.88 (2039.21)"      "M (var) = 381.30 (105856.62)" 
##                               Texas                                Utah 
##    "M (var) = 1724.72 (3718210.82)"        "M (var) = 108.93 (9195.08)" 
##                             Vermont                            Virginia 
##          "M (var) = 42.56 (385.89)"     "M (var) = 1055.32 (620667.41)" 
##                       West Virginia                           Wisconsin 
##         "M (var) = 59.00 (1603.68)"      "M (var) = 392.94 (118066.40)" 
##                             Wyoming 
##           "M (var) = 12.02 (87.56)"
#check if variance is greater than mean for hospitalizations
with(data, tapply(hospitalizations, Province.State, function(x) {
    sprintf("M (var) = %1.2f (%1.2f)", mean(x, na.rm=TRUE), var(x, na.rm=TRUE))
}))
##                                                               Alabama 
##               "M (var) = NaN (NA)"   "M (var) = 4220.78 (6309089.98)" 
##                             Alaska                            Arizona 
##        "M (var) = 53.81 (1009.80)"          "M (var) = 10060.00 (NA)" 
##                           Arkansas                         California 
##    "M (var) = 1266.76 (562035.98)"               "M (var) = NaN (NA)" 
##                           Colorado                        Connecticut 
##   "M (var) = 3828.04 (4031948.94)"               "M (var) = NaN (NA)" 
##                           Delaware               District of Columbia 
##               "M (var) = NaN (NA)"               "M (var) = NaN (NA)" 
##                            Florida                            Georgia 
## "M (var) = 10328.16 (53739973.62)"  "M (var) = 8243.03 (24344251.09)" 
##                             Hawaii                              Idaho 
##        "M (var) = 89.88 (2282.84)"      "M (var) = 306.66 (41890.98)" 
##                           Illinois                            Indiana 
##               "M (var) = NaN (NA)"    "M (var) = 7316.12 (872382.42)" 
##                               Iowa                             Kansas 
##               "M (var) = NaN (NA)"     "M (var) = 744.15 (201000.59)" 
##                           Kentucky                          Louisiana 
##    "M (var) = 2946.13 (195215.09)"               "M (var) = NaN (NA)" 
##                              Maine                           Maryland 
##      "M (var) = 260.07 (10366.05)"  "M (var) = 7658.91 (16012307.53)" 
##                      Massachusetts                           Michigan 
##    "M (var) = 1687.30 (215771.79)"               "M (var) = NaN (NA)" 
##                          Minnesota                        Mississippi 
##   "M (var) = 2584.30 (2912611.07)"   "M (var) = 1842.72 (1175549.93)" 
##                           Missouri                            Montana 
##               "M (var) = NaN (NA)"        "M (var) = 88.02 (2773.30)" 
##                           Nebraska                             Nevada 
##     "M (var) = 1338.55 (52177.03)"               "M (var) = NaN (NA)" 
##                      New Hampshire                         New Jersey 
##      "M (var) = 448.72 (43888.19)"               "M (var) = NaN (NA)" 
##                         New Mexico                           New York 
##    "M (var) = 1496.10 (532600.29)"               "M (var) = NaN (NA)" 
##                     North Carolina                       North Dakota 
##               "M (var) = NaN (NA)"      "M (var) = 148.54 (11686.26)" 
##                               Ohio                           Oklahoma 
##   "M (var) = 5466.83 (9814958.57)"    "M (var) = 1147.75 (665100.00)" 
##                             Oregon                       Pennsylvania 
##     "M (var) = 738.48 (159758.80)"               "M (var) = NaN (NA)" 
##                       Rhode Island                     South Carolina 
##               "M (var) = NaN (NA)"               "M (var) = NaN (NA)" 
##                       South Dakota                          Tennessee 
##      "M (var) = 428.98 (73272.33)"   "M (var) = 1971.84 (1622444.79)" 
##                              Texas                               Utah 
##               "M (var) = NaN (NA)"     "M (var) = 992.18 (493498.92)" 
##                            Vermont                           Virginia 
##               "M (var) = NaN (NA)"   "M (var) = 4143.43 (6448855.13)" 
##                      West Virginia                          Wisconsin 
##               "M (var) = NaN (NA)"   "M (var) = 2553.50 (1478238.41)" 
##                            Wyoming 
##        "M (var) = 92.00 (1419.42)"

Because the counts are right-skewed and the variance is much greater than the mean for all of the variables, I explore the use of negative binomial regression to model the death count. This has also been done in other studies (https://www.nber.org/papers/w27391, https://jamanetwork.com/journals/jamanetworkopen/article-abstract/2779417). I include an offset term given by the log of the population, and I allow deaths to vary by location_id/state.

I have plotted the actual vs predicted death count for 3 negative binomial regression models, with deaths (the cumulative deaths column) as the outcome. In the first model, I use only cases as a predictor with a population offset term, while accounting for state-level differences, so the prediction varies by state.

In the second model, I use cases and hospitalizations as predictors, with the same population offset and state-level term. The third model uses only cases as the predictor with a population offset term and no state-level term.
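
As a minimal sketch of what the population offset does, the example below simulates counts whose expected value is population times a per-person rate and then fits a negative binomial model with `offset(log(population))`. All numbers, and the `cases_per_1k` predictor, are made up for illustration; this is not the model fitted to the COVID data.

```r
# Simulated data; every value here is hypothetical
set.seed(1)
n <- 200
population <- round(runif(n, 5e5, 5e6))
cases_per_1k <- runif(n, 0, 10)
true_rate <- exp(-9 + 0.2 * cases_per_1k)      # deaths per person
deaths <- rnbinom(n, size = 5, mu = population * true_rate)

# offset(log(population)) fixes the coefficient of log(population) at 1,
# so the remaining terms describe log(deaths / population)
fit <- MASS::glm.nb(deaths ~ cases_per_1k + offset(log(population)))
print(coef(fit))
```

The fitted coefficients are on the log-rate scale, so `exp()` of the slope can be read as a rate ratio per unit of the predictor.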

5.2 Negative binomial

library(pscl)
library(lme4)

#negative binomial regression
#simple model, only cases as the predictor
model1 = MASS::glm.nb(formula = deaths ~ cases + factor(location_id) + offset(log(population)), 
               data = data, 
               na.action = na.exclude)


#includes hospitalizations as an additional predictor
#(rows without hospitalization data are excluded from this fit)
model2 = MASS::glm.nb(formula = deaths ~ cases + hospitalizations + factor(location_id) + offset(log(population)), 
               data = data, 
               na.action = na.exclude)


#most parsimonious model, does not account for state 
model3 = MASS::glm.nb(formula = deaths ~ cases + offset(log(population)), 
               data = data, 
               na.action = na.exclude)


#predicted deaths from each model (NA where a row was excluded for missing data)
data$predicted_deaths1 <- predict(model1, type = "response")
data$predicted_deaths2 <- predict(model2, type = "response")
data$predicted_deaths3 <- predict(model3, type = "response")


#actual no of deaths (x) vs predicted deaths (y)

ggplot(data, aes(x = deaths, y = predicted_deaths1, color=as.factor(Province.State))) +
 scale_color_viridis(discrete=TRUE, option = "D")+
  geom_point(alpha = 0.5) +
  labs(title = "Model 1: deaths ~ cases",
       x = "Actual Deaths",
       y = "Predicted Deaths") +
  theme_minimal() +
   theme(plot.title = element_text(color="black", size=10, face="bold.italic"))

ggplot(data, aes(x = deaths, y = predicted_deaths2, color=as.factor(Province.State))) +
 scale_color_viridis(discrete=TRUE, option = "D")+
  geom_point(alpha = 0.5) +
  labs(title = "Model 2: deaths ~ cases + hospitalizations",
       x = "Actual Deaths",
       y = "Predicted Deaths") +
  theme_minimal() +
   theme(plot.title = element_text(color="black", size=10, face="bold.italic"))

ggplot(data, aes(x = deaths, y = predicted_deaths3, color=as.factor(Province.State))) +
 scale_color_viridis(discrete=TRUE, option = "D")+
  geom_point(alpha = 0.5) +
  labs(title = "Model 3: deaths ~ cases (no state)",
       x = "Actual Deaths",
       y = "Predicted Deaths") +
  theme_minimal() +
   theme(plot.title = element_text(color="black", size=10, face="bold.italic"))
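
The models above are fit to the cumulative deaths column. Since the question asks about daily deaths, one option is to difference the cumulative series within each state before modelling. The sketch below shows the differencing step on a toy data frame; the values are made up, and only the column names `Province.State`, `newdate` and `deaths` mirror the dataset.

```r
# Toy cumulative deaths for two hypothetical states
toy <- data.frame(
  Province.State = rep(c("X", "Y"), each = 4),
  newdate = rep(as.Date("2020-08-01") + 0:3, times = 2),
  deaths = c(10, 12, 15, 15, 0, 1, 1, 4)
)

# Sort by state and date so that differences are taken in time order
toy <- toy[order(toy$Province.State, toy$newdate), ]

# Daily deaths = change in cumulative deaths within each state;
# the first day of each state has no previous value, so it is NA
toy$daily_deaths <- ave(toy$deaths, toy$Province.State,
                        FUN = function(x) c(NA, diff(x)))
print(toy)
```

The resulting `daily_deaths` column could replace `deaths` as the outcome in the negative binomial models, which may also reduce the strong time trend that cumulative counts carry.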

6 Q3. Create projections for 14-days after the last observed data point. Visualize the result. Describe the benefits and limitations of your approach. Where do you think this approach has performed particularly well? What types of situations cause your model to struggle?

I decided to use ARIMA for time-series forecasting for several reasons. ARIMA (AutoRegressive Integrated Moving Average) combines an autoregressive component that uses past values, differencing of the observations to make the time series stationary, and a moving average component that models the error as a combination of prior error terms.

First, ARIMA is better suited to handling nonstationarity than ordinary linear regression methods. Second, ARIMA is easier to use than classical nonlinear epidemiological models such as SIR or SEIR, for which no closed-form exact solutions exist and numerical simulations are required [Barlow NS, Weinstein SJ. Accurate closed-form solution of the SIR epidemic model. Physica D. 2020;408:132540. doi: 10.1016/j.physd.2020.132540.]. Recent studies have also reported that ARIMA models outperformed SIR in predicting COVID-19 cases [Abuhasel KA, Khadr M, Alquraish MM. Analyzing and forecasting COVID-19 pandemic in the Kingdom of Saudi Arabia using ARIMA and SIR models. Comput Intell. 2022;38:770–783. doi: 10.1111/coin.12407; Abolmaali S, Shirzaei S. A comparative study of SIR Model, Linear Regression, Logistic Function and ARIMA Model for forecasting COVID-19 cases. AIMS Public Health. 2021;8:598–613. doi: 10.3934/publichealth.2021048.].

Benefits: If the time series is stationary, or can be made stationary by differencing, the ARIMA model tends to work well and can produce accurate short-term forecasts.

Limitations: ARIMA performs poorly if the residuals are not normally distributed or if they are correlated. It also assumes a constant mean and variance over time once the series has been differenced. ARIMA is also very sensitive to missing data, which is why the hospitalization forecasts below are restricted to states that report hospitalizations.
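
The roles of differencing and of the residual check can be sketched with base R alone. The example below uses `stats::arima` with a fixed, hypothetical order rather than `auto.arima`, on simulated daily counts: the daily series is modelled, forecast 14 days ahead, and accumulated back onto the last cumulative value. A Ljung-Box test then checks whether the residuals still look autocorrelated. None of the numbers relate to the COVID data.

```r
# Simulated daily counts and their cumulative sum (hypothetical values)
set.seed(7)
daily <- rpois(120, lambda = 20)
cumulative <- cumsum(daily)

# First differences of the cumulative series recover the daily series
recovered_daily <- diff(cumulative)

# AR(1) with a mean on the daily series; the order is an assumption for this sketch
fit <- arima(daily, order = c(1, 0, 0))

# 14-day-ahead forecast of daily counts, then back to the cumulative scale
daily_fc <- as.numeric(predict(fit, n.ahead = 14)$pred)
cumulative_fc <- tail(cumulative, 1) + cumsum(daily_fc)

# Ljung-Box test: a small p-value would suggest leftover autocorrelation
lb <- Box.test(residuals(fit), lag = 10, type = "Ljung-Box")
print(cumulative_fc)
print(lb$p.value)
```

Modelling the differenced series and re-accumulating is equivalent in spirit to letting the "I" part of ARIMA handle the differencing; it is written out here only to make the steps visible.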

library(forecast)
## Warning: package 'forecast' was built under R version 4.3.3
## Registered S3 method overwritten by 'quantmod':
##   method            from
##   as.zoo.data.frame zoo
library(dplyr)

#Create a state-specific 14 day forecast for cases, deaths, hospitalization
states =  c("Alabama", "Alaska", "Arizona", "Arkansas", "California", "Colorado", 
            "Connecticut", "Delaware", "Florida", "Georgia", "Hawaii", "Idaho", 
            "Illinois", "Indiana", "Iowa", "Kansas", "Kentucky", "Louisiana", 
            "Maine", "Maryland", "Massachusetts", "Michigan", "Minnesota", 
            "Mississippi", "Missouri", "Montana", "Nebraska", "Nevada", 
            "New Hampshire", "New Jersey", "New Mexico", "New York", "North Carolina", 
            "North Dakota", "Ohio", "Oklahoma", "Oregon", "Pennsylvania", 
            "Rhode Island", "South Carolina", "South Dakota", "Tennessee", "Texas", 
            "Utah", "Vermont", "Virginia", "West Virginia", 
            "Wisconsin", "Wyoming", "District of Columbia")

##CASES
forecast_by_state_cases <- function(data) {
  for (state in states) {
    # Filter data for the current state
    data_subset <- data %>%
      filter(Province.State == state) %>%
      select(cases)
    
    # Fit ARIMA model and forecast
    fit <- auto.arima(data_subset$cases)
    forecastedValues <- forecast(fit, 14)
    
    # Plot the forecast
    
    plot(forecastedValues, main = paste("Forecast for", state), 
        col.main = "darkgreen",
        xlab = "Time in days",
        ylab = "Cases")
  
  }
}

forecast_by_state_cases(data)

##DEATHS 
forecast_by_state_deaths <- function(data) {
  for (state in states) {
    # Filter data for the current state
    data_subset <- data %>%
      filter(Province.State == state) %>%
      select(deaths)
    
    # Fit ARIMA model and forecast
    fit <- auto.arima(data_subset$deaths)
    forecastedValues <- forecast(fit, 14)
    
    # Plot the forecast
    
    plot(forecastedValues, main = paste("Forecast for", state), 
        col.main = "purple",
        xlab = "Time in days",
        ylab = "Deaths")
  
  }
}

forecast_by_state_deaths(data)

##HOSPITALIZATIONS 

#keeping only rows with non-missing hospitalization data
hosp_data = data %>%
  filter(complete.cases(hospitalizations))

state_hosp_name = c("Alabama", "Alaska","Arizona","Arkansas","Colorado", "Florida", "Georgia","Hawaii","Idaho", "Indiana","Kansas","Kentucky","Maine","Maryland", "Massachusetts", "Minnesota","Mississippi","Montana","Nebraska", "New Hampshire", "New Mexico" ,"North Dakota","Ohio", "Oklahoma","Oregon", "South Dakota","Tennessee","Utah","Virginia", "Wisconsin", "Wyoming")      

forecast_by_state_hosps <- function(data) {
  for (state in state_hosp_name) {
    # Filter data for the current state
    data_subset <- data %>%
        filter(Province.State == state) %>%
        select(hospitalizations)
    
    # Fit ARIMA model and forecast
    fit <- auto.arima(data_subset$hospitalizations)
    forecastedValues <- forecast(fit, 14)
    
    # Plot the forecast
    
    plot(forecastedValues, main = paste("Forecast for", state), 
        col.main = "darkblue",
        xlab = "Time in days",
        ylab = "Hospitalizations")
  
  }
}

forecast_by_state_hosps(hosp_data)

7 Q4. Lastly, describe future areas of exploration or improvement for your approach. If you had more time, what would you do next?

One assumption made is that historical trends alone are sufficient to predict future values. This may not hold for epidemic modelling, as other factors such as new variants, changing climates, vaccination trends, and others can all affect trajectories. Inter-state dynamics or relationships between the different states could also impact these trajectories. If I had more time, I would look into these additional factors and include them in a model to make more accurate predictions.

Assuming we have access to other data sources, I would like to use neural networks for forecasting. Currently we are doing univariate forecasting; neural networks (such as Transformers (https://arxiv.org/abs/1706.03762)) would allow multivariate forecasting. Another potential benefit of Transformers over other architectures is that they can be set up to accommodate missing values (which are common in the time series setting). However, one drawback is that training these models requires access to large datasets and is computationally expensive.